package com.boco.common.search;

import com.boco.domain.ExportSupportJpaRepository;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.jpa.domain.Specification;
import org.springframework.data.rest.webmvc.RootResourceInformation;
import org.springframework.stereotype.Component;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;

import javax.persistence.criteria.CriteriaBuilder;
import javax.persistence.criteria.Path;
import javax.persistence.criteria.Predicate;
import javax.persistence.criteria.Root;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Map;

/**
 * {@link JpqlGenerator} implementation that converts flat string parameters into a JPA
 * {@link Specification}.
 *
 * <p>A parameter key is either a bare field name (treated as an equality test) or a field name
 * followed by a single {@code '_'}-prefixed suffix (for example {@code name_xxx}); the suffix,
 * including the underscore, is handed to {@code Operators.fromFlag} to pick the operator.
 * Keys that do not match a field of the domain class are logged and skipped, so only real field
 * names ever reach the criteria API.
 *
 * @author pandengke/pdkkpdk@163.com
 * @date 2018/9/30
 */
@Component
public class JpqlGeneratorImpl implements JpqlGenerator {

    private static final Logger logger = LoggerFactory.getLogger(JpqlGeneratorImpl.class);

    /** Separates the field name from the operator suffix inside a parameter key. */
    private static final char FLAG_SEPARATOR = '_';

    /** Expected textual form of date values; {@code HH} is the 24-hour field (0-23). */
    private static final String DATE_PATTERN = "yyyy-MM-dd HH:mm:ss";

    /** When true, blank values and the paging keys (page/size/sort) are ignored while parsing. */
    private boolean clearEmptyKey = true;

    /**
     * Translates request parameters into search parts.
     *
     * <p>The supplied map is only read, never modified. Entries are processed in the map's
     * iteration order.
     *
     * @param parameters  raw key/value pairs; {@code null} is treated as empty
     * @param domainClass class whose fields are the only accepted search keys
     * @return one {@link Part} per recognised parameter, possibly empty, never {@code null}
     * @throws IllegalArgumentException if a key contains more than one {@code '_'}
     */
    public List<Part> parse(Map<String, String> parameters, Class<?> domainClass) {
        List<Part> parts = new ArrayList<>();
        if (parameters == null) {
            return parts;
        }
        for (Map.Entry<String, String> entry : parameters.entrySet()) {
            String key = entry.getKey();
            String value = entry.getValue();
            if (clearEmptyKey && isIgnored(key, value)) {
                continue;
            }
            String fieldName = extractFieldName(key);
            // Confirm the field really exists on the domain class; this whitelist is what keeps
            // arbitrary caller-supplied names out of the generated query.
            Field field = ReflectionUtils.findField(domainClass, fieldName);
            if (field == null) {
                // Unrecognised field: report it and skip this parameter.
                logger.error("the class {} has no field {}", domainClass, fieldName);
                continue;
            }
            parts.add(new Part(fieldName, resolveOperator(key, fieldName), value, field.getType()));
        }
        return parts;
    }

    /**
     * Tells whether a parameter must not take part in the search: a {@code null} key, an empty
     * value, or one of the paging keys {@code page}, {@code size}, {@code sort}.
     */
    private boolean isIgnored(String key, String value) {
        return key == null
                || StringUtils.isEmpty(value)
                || "page".equals(key)
                || "size".equals(key)
                || "sort".equals(key);
    }

    /**
     * Returns the part of {@code key} that precedes the separator, or the whole key when there
     * is no separator.
     *
     * @throws IllegalArgumentException if the key contains more than one separator
     */
    private String extractFieldName(String key) {
        int first = key.indexOf(FLAG_SEPARATOR);
        if (first < 0) {
            return key;
        }
        if (key.indexOf(FLAG_SEPARATOR, first + 1) >= 0) {
            throw new IllegalArgumentException(
                    "search key cannot contain more than one '_' character: " + key);
        }
        // indexOf/substring instead of split("_")[0]: split drops trailing empty strings, so a
        // key consisting only of "_" would produce an empty array.
        return key.substring(0, first);
    }

    /**
     * Picks the operator for a key: equality when the key is just the field name, otherwise
     * whatever {@code Operators.fromFlag} returns for the suffix (underscore included).
     */
    private Operators resolveOperator(String key, String fieldName) {
        if (key.indexOf(FLAG_SEPARATOR) < 0) {
            return Operators.EQUAL;
        }
        // Strip the field name as a prefix only; String.replace would also remove occurrences
        // of the field name that happen to appear inside the suffix.
        return Operators.fromFlag(key.substring(fieldName.length()));
    }

    /**
     * Builds a specification combining one predicate per recognised parameter.
     *
     * @param parameters  raw key/value pairs
     * @param domainClass class whose fields are the only accepted search keys
     * @param allMatch    {@code true} to AND the predicates together, {@code false} to OR them
     * @return the specification, or {@code null} when no parameter produced a search part
     */
    @Override
    @SuppressWarnings({"rawtypes", "unchecked"})
    public Specification geneSpecification(Map<String, String> parameters, Class<?> domainClass, boolean allMatch) {
        List<Part> parts = parse(parameters, domainClass);
        if (parts == null || parts.isEmpty()) {
            return null;
        }
        return (root, query, cb) -> {
            Predicate[] predicates = parts.stream()
                    .map(part -> partToPredicate(part, root, cb))
                    .toArray(Predicate[]::new);
            return allMatch ? cb.and(predicates) : cb.or(predicates);
        };
    }

    /**
     * Converts a single part into a criteria predicate on {@code root}.
     * Operators without a dedicated branch are handled as equality.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private Predicate partToPredicate(Part part, Root root, CriteriaBuilder cb) {
        Object value = parseValue(part);
        Path path = root.get(part.getField());
        switch (part.getFlag()) {
            case GTE:
                return cb.greaterThanOrEqualTo(path, (Comparable) value);
            case LTE:
                return cb.lessThanOrEqualTo(path, (Comparable) value);
            case LIKE:
                return cb.like(path, (String) value);
            case NE:
                return cb.notEqual(path, value);
            case EQUAL:
            default:
                return cb.equal(path, value);
        }
    }

    /**
     * Converts the part's value to match the field type: enum constants are looked up by name
     * and date fields are parsed with {@link #DATE_PATTERN}; anything else is passed through as
     * returned by {@code part.outValue()}.
     *
     * @throws IllegalArgumentException if an enum field receives a name that is not a constant
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private Object parseValue(Part part) {
        Object value = part.outValue();
        Class<?> type = part.getType();
        if (type.isEnum()) {
            return Enum.valueOf((Class<Enum>) type, part.getValue());
        }
        // Date.class.isAssignableFrom(type) is true for Date and its subclasses; the reverse
        // test would instead match Object/Comparable/Serializable fields.
        if (Date.class.isAssignableFrom(type) && !StringUtils.isEmpty(value)) {
            String text = String.valueOf(value);
            try {
                // SimpleDateFormat is not thread-safe, so a fresh instance is used per call.
                return new SimpleDateFormat(DATE_PATTERN).parse(text);
            } catch (ParseException e) {
                // Keep the unparsed value, as before, but report through the logger.
                logger.error("cannot parse '{}' as a date using pattern {} for field {}",
                        text, DATE_PATTERN, part.getField(), e);
            }
        }
        return value;
    }

    /**
     * Builds the specification for the resource's domain type and runs it through the given
     * repository.
     *
     * @param repository1         repository used to execute the query
     * @param resourceInformation supplies the domain type the parameters are checked against
     * @param params              raw search parameters
     * @param repository          not used by this implementation
     * @param pageable            paging information forwarded to the repository
     * @param allMatch            {@code true} to AND the conditions, {@code false} to OR them
     * @return the page returned by the repository
     */
    @Override
    @SuppressWarnings({"rawtypes", "unchecked"})
    public Page<Object> execQuery(ExportSupportJpaRepository repository1, RootResourceInformation resourceInformation, Map<String, String> params, String repository, Pageable pageable, boolean allMatch) throws InvocationTargetException, IllegalAccessException {
        Specification specification = geneSpecification(params, resourceInformation.getDomainType(), allMatch);
        return (Page<Object>) repository1.findAll(specification, pageable);
    }
}
